# core/formalization/rl/exp.py
import torch
import numpy as np
from typing import Dict, List, Tuple, Any, Optional
from collections import deque
import random
from dataclasses import dataclass


@dataclass
class StepExp:
    """Record of one environment step as stored and batched by ``ExpBuffer``."""

    # Observation for this step; batches stack these with np.array(...) and
    # convert the result to a float tensor.
    state: np.ndarray
    # Index of the action taken; batched into a LongTensor.
    action: int
    # Log-probability recorded for `action`; surfaces in batches as
    # `old_log_probs`.
    log_prob: float
    # Value estimate for `state`; used as V(s) in the GAE temporal-difference
    # residual and added to the advantage to form the return target.
    value: float
    # Reward term used in the GAE temporal-difference residual.
    exp_reward: float
    # Termination flag -- presumably True when the episode ended on this
    # step; TODO confirm against the code that produces trajectories.
    done: bool
    # Per-step mask array, stacked and converted to a float tensor when
    # batching. Presumably an action-validity mask -- confirm with the caller.
    mask: np.ndarray
    # Named reward parts. Batching reads the keys 'rs', 're', 'rh' and 'rd'
    # and substitutes 0.0 for any that are missing.
    reward_components: Dict[str, float]

@dataclass
class TrajectoryBatch:
    """Tensor batch of sampled steps, as built by ``ExpBuffer.get_training_batch``.

    Every field has one entry (or row) per sampled step, all in the same
    sample order.
    """

    # Stacked `StepExp.state` arrays (float tensor).
    states: torch.Tensor
    # `StepExp.action` values (long tensor).
    actions: torch.Tensor
    # `StepExp.log_prob` values recorded when the steps were collected.
    old_log_probs: torch.Tensor
    # `StepExp.value` estimates recorded when the steps were collected.
    values: torch.Tensor
    # `StepExp.exp_reward` values.
    exp_rewards: torch.Tensor
    # GAE advantages computed per trajectory by `ExpBuffer.compute_gae`.
    advantages: torch.Tensor
    # Element-wise `advantages + values`.
    returns: torch.Tensor
    # Stacked `StepExp.mask` arrays (float tensor).
    masks: torch.Tensor
    # [batch_size, 4]; columns are the 'rs', 're', 'rh', 'rd' entries of
    # `StepExp.reward_components`, in that order (0.0 where a key is absent).
    reward_components: torch.Tensor

class ExpBuffer:
    """Storage for single-step experiences and whole trajectories.

    Two independent stores are kept:

    * ``buffer`` -- a bounded FIFO of individual steps (``add_exp``);
    * ``trajectories`` -- a list of step sequences (``add_trajectory``),
      which is what ``get_training_batch`` samples from.
    """

    # Column order of the ``reward_components`` tensor in a training batch.
    _REWARD_KEYS = ('rs', 're', 'rh', 'rd')

    def __init__(self, capacity: int = 10000):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of single steps kept in ``buffer``.
                The trajectory store is limited to ``capacity // 10``
                trajectories.
        """
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)
        self.trajectories = []

    def add_exp(self, exp: "StepExp"):
        """Append one step to ``buffer``; the oldest step is dropped when full."""
        self.buffer.append(exp)

    def get_size(self):
        """Return the number of single steps currently held in ``buffer``."""
        return len(self.buffer)

    def add_trajectory(self, trajectory: List["StepExp"]):
        """Store a trajectory, evicting the oldest one when over the limit.

        NOTE: the limit is ``capacity // 10``, so with ``capacity < 10`` it is
        zero and the trajectory just added is evicted again immediately.
        """
        self.trajectories.append(trajectory)
        if len(self.trajectories) > self.capacity // 10:
            self.trajectories.pop(0)

    def compute_gae(
        self,
        trajectory: List["StepExp"],
        discount_factor: float = 0.99,
        smooth_factor: float = 0.95,
        last_value: float = 0.0,
    ) -> List[float]:
        """Compute Generalized Advantage Estimation for one trajectory.

        Walking backwards through the steps::

            delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
            A_t     = delta_t + gamma * lambda * A_{t+1}

        A step with ``done`` set is an episode boundary: neither the value of
        the following step nor the advantage accumulated after it is carried
        back across it.

        Args:
            trajectory: Steps in chronological order.
            discount_factor: Reward discount (gamma).
            smooth_factor: GAE smoothing coefficient (lambda).
            last_value: Value estimate of the state that follows the final
                step. Only used when the final step is not ``done`` (i.e. the
                trajectory was cut off mid-episode). The default of 0.0
                applies no bootstrap.

        Returns:
            One advantage per step, in the same order as ``trajectory``;
            an empty list for an empty trajectory.
        """
        if not trajectory:
            return []

        advantages = []
        gae = 0.0
        next_value = last_value

        for exp in reversed(trajectory):
            if exp.done:
                # Terminal step: nothing after it belongs to this episode.
                gae = exp.exp_reward - exp.value
            else:
                delta = exp.exp_reward + discount_factor * next_value - exp.value
                gae = delta + discount_factor * smooth_factor * gae
            advantages.append(gae)
            next_value = exp.value

        # Built back-to-front; a single reverse avoids repeated front inserts.
        advantages.reverse()
        return advantages

    def get_training_batch(self, batch_size: int) -> Optional["TrajectoryBatch"]:
        """Sample ``batch_size`` steps, without replacement, across all trajectories.

        Advantages are computed per trajectory with ``compute_gae`` (default
        hyper-parameters) on every call; returns are ``advantage + value``.

        Args:
            batch_size: Number of steps to sample.

        Returns:
            A ``TrajectoryBatch`` of tensors, or ``None`` when no trajectory is
            stored or fewer than ``batch_size`` steps are available.
        """
        if not self.trajectories:
            return None

        all_experiences = []
        all_advantages = []
        for trajectory in self.trajectories:
            all_advantages.extend(self.compute_gae(trajectory))
            all_experiences.extend(trajectory)

        if len(all_experiences) < batch_size:
            return None

        # The guard above guarantees that enough steps exist for the sample.
        indices = random.sample(range(len(all_experiences)), batch_size)
        sampled = [all_experiences[idx] for idx in indices]
        advantages_batch = [all_advantages[idx] for idx in indices]

        values = [exp.value for exp in sampled]
        returns = [adv + val for adv, val in zip(advantages_batch, values)]
        reward_components = [
            [exp.reward_components.get(key, 0.0) for key in self._REWARD_KEYS]
            for exp in sampled
        ]

        return TrajectoryBatch(
            states=torch.FloatTensor(np.array([exp.state for exp in sampled])),
            actions=torch.LongTensor([exp.action for exp in sampled]),
            old_log_probs=torch.FloatTensor([exp.log_prob for exp in sampled]),
            values=torch.FloatTensor(values),
            exp_rewards=torch.FloatTensor([exp.exp_reward for exp in sampled]),
            advantages=torch.FloatTensor(advantages_batch),
            returns=torch.FloatTensor(returns),
            masks=torch.FloatTensor(np.array([exp.mask for exp in sampled])),
            reward_components=torch.FloatTensor(reward_components),
        )

    def clear(self):
        """Remove every stored step and trajectory."""
        self.buffer.clear()
        self.trajectories.clear()